#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 25 11:32:05 2025

@author: jcfq2
"""


import os

# Run from the JWST data directory: relative paths used later in this script
# (e.g. 'F323N_filter.txt') are resolved against it. In interactive sessions,
# where the cwd is on sys.path, this presumably also lets the local helper
# modules imported below be found -- confirm.
jwst_dir = '/Users/jcfq2/data/observations/jwst'
os.chdir(jwst_dir)

import matplotlib.pyplot as plt
from astropy.visualization import make_lupton_rgb
from astropy.table import Table
import ch4_fiddlesticks as ch4
import pandas as pd
import h3ppy
import glob
import spiceypy as spice
from astropy.io import fits
import numpy as np
from importlib import reload
import JWSTSolarSystemPointing as jssp
from scipy.interpolate import interp1d



# Re-execute the local pointing module so edits take effect when this script
# is re-run in the same interactive session.
reload(jssp)

# Dither-separated cubes, sorted so the index below is reproducible.
files = sorted(glob.glob(
    "/Users/jcfq2/data/observations/jwst/5308_dither_separated/*.fits"))

# Which cube to analyse (originally written inline as 13*4+0).
# NOTE(review): presumably "group 13 of 4 dithers, dither 0" -- confirm.
file_index = 13 * 4 + 0
if len(files) <= file_index:
    # A wrong path or missing data would otherwise surface as a bare
    # "list index out of range".
    raise FileNotFoundError(
        f"need at least {file_index + 1} FITS files in the dither directory, "
        f"found {len(files)}")
geo = jssp.JWSTSolarSystemPointing(files[file_index])
wave = geo.get_wavelength()
#


# Spectral window for the F323N analysis (same units as geo.get_wavelength(),
# presumably microns); whw_F323 holds the channel indices strictly inside it.
wavemin_F323 = 3.148
wavemax_F323 = 3.35
in_F323_window = (wave > wavemin_F323) & (wave < wavemax_F323)
whw_F323 = np.argwhere(in_F323_window).flatten()

# PSG background model for F323N, linearly interpolated onto the channels in
# the window (interp1d raises ValueError if any channel lies outside the
# table's wavelength range).
# NOTE: treat this background estimate with caution -- many features
# contribute to the background in this band.
file_331 = '/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_F323.txt'
bck_F323 = np.loadtxt(file_331)
psg_wave, psg_value = bck_F323[:, 0], bck_F323[:, 1]
bck_F323_whw = interp1d(psg_wave, psg_value, kind='linear')(wave[whw_F323])




# F323N throughput curve (column 0 wavelength, column 1 throughput --
# presumably; confirm), read relative to the working directory and
# linearly resampled onto the channels inside the analysis window.
F323_values = np.loadtxt('F323N_filter.txt', delimiter=' ')
filter_interp = interp1d(F323_values[:, 0], F323_values[:, 1], kind='linear')
filter_F323_sub = filter_interp(wave[whw_F323])

# Full-length throughput vector on the cube's grid, zero outside the window.
filter_F323 = np.zeros_like(wave)
filter_F323[whw_F323] = filter_F323_sub

# Independent copies that are later restricted to H3+ / non-H3+ channels.
filter_F323_h3p = filter_F323.copy()
filter_F323_without_h3p = filter_F323.copy()

# Weight the resampled background by the filter throughput.
bck_F323_whw = bck_F323_whw * filter_F323_sub

# Synthetic H3+ spectrum on the cube's wavelength grid
# (T=500, N=6e14, R=4700), scaled so its peak is 1.
h3p = h3ppy.h3p()
h3p.set(T=500, N=6e14, wave=wave, R=4700)
h3p_raw = h3p.model()
h3p_model = h3p_raw / np.max(h3p_raw)

# Synthetic H2 spectrum on the same grid (temperature 500, R=4700).
# NOTE(review): one amagat is ~2.687e25 m^-3 (Loschmidt constant); confirm the
# value 2.76e25 and the density units h3ppy expects.
amagat = 2.76e25
h2 = h3ppy.h2()
h2.set(temperature=500, density=amagat, R=4700, wavelength=wave)
model_h2 = h2.model()
# NOTE(review): h3p_model was already peak-normalised above, so this divides by
# 1.0 (unless the model contains NaN) and model_h2 stays in raw units. To keep
# H2 on the same relative scale as H3+, divide by np.max(h3p_raw) instead; to
# normalise H2 on its own, use np.max(model_h2). Behaviour left unchanged.
model_h2 = model_h2 / np.max(h3p_model)


# Split the filter by the peak-normalised H3+ model: filter_F323_h3p keeps the
# channels where it is strong (>= 0.005), filter_F323_without_h3p keeps those
# where it is negligible (<= 0.001); channels in between belong to neither.
filter_F323_h3p[h3p_model < 0.005] = 0
filter_F323_without_h3p[h3p_model > 0.001] = 0


def _pm150_mask(weights):
    """Return a copy of ``weights`` with non-zero entries set to 150 and zeros to -150."""
    shade = weights.copy()
    active = shade != 0
    shade[active] = 150
    shade[~active] = -150
    return shade


# +/-150 step masks used to shade the non-H3+ (mask) and H3+ (mask2)
# channels in the spectral panel.
filter_F323_h3p_mask = _pm150_mask(filter_F323_without_h3p)
filter_F323_h3p_mask2 = _pm150_mask(filter_F323_h3p)


# Drop the first 3 and last 2 positions along the cube's last axis.
sy, ey = 3, -2
geoim = geo.im[:, :, sy:ey]

# H3+ map: a sum of selected channels minus a weighted combination of others.
# NOTE(review): channel indices and weights are hard-coded; presumably ab
# collects H3+ line channels and cd estimates the underlying continuum/CH4
# contribution -- confirm against the spectrum.
_ab_channels = (2239, 1568, 1642, 1010, 1011, 1018, 1019)
ab = geoim[_ab_channels[0], :, :]
for _chan in _ab_channels[1:]:
    ab = ab + geoim[_chan, :, :]   # new array each step; never writes into geoim

_cd_part1 = geoim[2241, :, :] * 0.8 * 0.5 + geoim[2236, :, :] * 1.19 * 0.5
_cd_part2 = (geoim[1571, :, :] + geoim[1565, :, :]) * 0.5
_cd_part3 = 0.4 * geoim[1640, :, :] / 5.21
_cd_part4 = geoim[1004, :, :] + geoim[1005, :, :] * 2
cd = _cd_part1 + _cd_part2 + _cd_part3 + _cd_part4

img_h3p = ab - cd
 



reload(ch4)  # re-execute the local CH4 module so edits take effect between runs

ch4fit = ch4.fit_non_LTE_CH4_JWST(wave)

# Spectrum of spaxel (10, 23), converted via geo.convert, then fitted.
spec = geo.convert(wave, geoim[:, 10, 23])
fit = ch4fit.fit(wave, spec)





def _filter_weighted_map(cube, weights):
    """Return the NaN-ignoring sum over axis 0 of ``cube * weights``.

    ``cube`` has the spectral channel on axis 0 and ``weights`` holds one
    filter value per channel, so the result is a 2-D map with one
    filter-integrated value per spaxel.
    """
    weights = np.asarray(weights)
    return np.nansum(cube * weights[:, np.newaxis, np.newaxis], axis=0)


# Filter-integrated maps of every spaxel: the full F323N window, only the
# channels where the H3+ model is strong, and only those where it is
# negligible. A single vectorised reduction per filter replaces the former
# per-pixel Python double loop (identical up to floating-point summation
# order). Writing into zeros_like(ab) keeps the maps in ab's dtype, as before.
h3p_image = np.zeros_like(ab)
noh3p_image = np.zeros_like(ab)
full_image = np.zeros_like(ab)

full_image[...] = _filter_weighted_map(geoim, filter_F323)
h3p_image[...] = _filter_weighted_map(geoim, filter_F323_h3p)
noh3p_image[...] = _filter_weighted_map(geoim, filter_F323_without_h3p)




fig = plt.figure(figsize=(9, 9))

# Explicit [left, bottom, width, height] rectangles in figure fractions.
# Figure.add_axes is the API for hand-placed Axes; plt.subplot(position=...)
# relies on pyplot's Axes-reuse rules and leaves every panel attached to the
# same 1x1 subplot grid, where a later subplots_adjust/tight_layout could move
# them all back on top of each other.
a1 = fig.add_axes([0, 2/3, 1/4, 1/3])      # full-filter image
a2 = fig.add_axes([1/4, 2/3, 1/4, 1/3])    # img_h3p
a3 = fig.add_axes([2/4, 2/3, 1/4, 1/3])    # H3+-channel image
a4 = fig.add_axes([3/4, 2/3, 1/4, 1/3])    # non-H3+-channel image
a5 = fig.add_axes([0, 1/3, 1, 1/3])        # spectral panel
a6 = fig.add_axes([0, 0, 1, 1/3])          # spectral panel

# Panels 2-4 abut their left neighbour, so blank their y tick labels.
for ax in (a2, a3, a4):
    ax.set_yticks(ticks=[0, 5, 10, 15, 20, 25, 30], labels=[''] * 7)

a3.imshow(h3p_image)
a4.imshow(noh3p_image)
a1.imshow(full_image, cmap='copper')
a2.imshow(img_h3p, cmap='Greens_r')



# Narrow the spectral plots to positions 80-199 of the F323N channel list.
# NOTE: whw_F323 is overwritten, so re-running only this section would shrink
# the selection a second time.
whw_F323 = whw_F323[80:200]

print(filter_F323_h3p_mask.shape)
# Trim both shading masks to the plotted channels, then pin their first and
# last samples to -150 so each fill() polygon starts and ends at the bottom.
filter_F323_h3p_mask = filter_F323_h3p_mask[whw_F323]
filter_F323_h3p_mask2 = filter_F323_h3p_mask2[whw_F323]
for _shade in (filter_F323_h3p_mask, filter_F323_h3p_mask2):
    _shade[0] = -150
    _shade[-1] = -150

print(filter_F323_h3p_mask.shape)

sel_F323 = whw_F323
filt_sel = filter_F323[sel_F323]

# Filter-weighted spectra of the auroral spaxel (10, 23) and the comparison
# spaxel (10, 19).
aur_spec = geoim[sel_F323, 10, 23] * filt_sel
nonaur_spec = geoim[sel_F323, 10, 19] * filt_sel

# Scaled CH4 hot-band / fundamental ratio, weighted by the filter.
ch4_hot_sel = ch4fit.ch4_hot[sel_F323] * ch4fit.hot_scaling
ch4_fun_sel = ch4fit.ch4_fun[sel_F323] * ch4fit.fun_scaling
excess_ch4_hot = 1.3 * 1e-3 * ch4_hot_sel / ch4_fun_sel * 0.6 * filt_sel

# H3+ and H2 models, both divided by the H3+ peak within the selection.
h3p_peak_sel = np.nanmax(h3p_model[sel_F323])
modelled_h3p = h3p_model[sel_F323] / h3p_peak_sel * 1.5 * filt_sel
modelled_h2 = model_h2[sel_F323] / h3p_peak_sel * 1.5 * filt_sel



# Normalise by the auroral peak with nanmax rather than max: a single NaN in
# the extracted spectrum would make np.max return NaN and blank every
# normalised curve. Identical to np.max when the spectrum has no NaNs.
aur_peak = np.nanmax(aur_spec)
plot_wave = wave[whw_F323]

# a5.plot(wave[whw_F323],geoim[whw_F323,23,10]/np.nanmax(geoim[whw_F323,23,10]))
a5.plot(plot_wave, aur_spec / aur_peak, label='Auroral spec')

# a5.plot(wave[whw_F323],geoim[whw_F323,23,10]*filter_F323[whw_F323]/(np.nanmax(geoim[whw_F323,23,10]*filter_F323[whw_F323])))
# a5.plot(wave[whw_F323],h3p_model[whw_F323]/np.nanmax(h3p_model[whw_F323]))
# a6.plot(wave[whw_F323],geoim[whw_F323,10,30]/np.nanmax(geoim[whw_F323,10,30]))
# a6.plot(wave[whw_F323],geoim[whw_F323,10,30]*filter_F323[whw_F323]/(np.nanmax(geoim[whw_F323,10,30]*filter_F323[whw_F323])))
a5.plot(plot_wave, nonaur_spec / aur_peak, label='Non-auroral spec')
a6.plot(plot_wave, modelled_h3p, label='Estimated H3+ spec', color='r', linestyle='dotted')
# a6.plot(wave[whw_F323],modelled_h2,label='Estimated H2 spec',color='r',linestyle='dotted')
# a5.plot(wave[whw_F323],bck_F323_whw/np.nanmax(bck_F323_whw)*7+10)

a6.plot(plot_wave, (aur_spec - nonaur_spec) / aur_peak, label='Difference in spectra', color='purple')
a6.plot(plot_wave, excess_ch4_hot, linestyle='dotted', label='CH4 hot/fund')
# modelled_h3p is exactly the H3+ term previously recomputed inline here.
a6.plot(plot_wave, excess_ch4_hot + modelled_h3p, linestyle='dashed', c='m', label='CH4 hot + H3+')
# Shade non-H3+ (blue) and H3+ (red) channels; the +/-150 masks extend far
# beyond the y-limits set next, so they render as vertical bands.
a6.fill(plot_wave, filter_F323_h3p_mask, 'b', alpha=0.05)
a6.fill(plot_wave, filter_F323_h3p_mask2, 'r', alpha=0.05)
a6.set_ylim(-0.15, 0.2)

a5.plot(plot_wave,
        1e5 * (ch4fit.ch4_fun[whw_F323] * ch4fit.fun_scaling
               + ch4fit.ch4_hot[whw_F323] * 0.45 * ch4fit.hot_scaling) * filter_F323[whw_F323],
        linestyle='dotted', label='Estimated CH4 fluor')
a5.legend()
a6.legend()

# a6.plot(ch4fit.ch4_fun[whw_F323]*ch4fit.fun_scaling*filter_F323[whw_F323])
# a6.plot(ch4fit.ch4_hot[whw_F323]*ch4fit.hot_scaling*0.45*filter_F323[whw_F323])

fig.savefig('asd_supplementary_fig1.pdf', dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0)

plt.show()

# Percentages of the integrated, peak-normalised auroral spectrum attributed to:
#   1) the excess CH4 hot-band estimate,
#   2) the modelled H3+,
#   3) what remains of the auroral-minus-non-auroral difference after both;
# the last line prints the integrated normalised auroral spectrum itself.
# nanmax (not max) so one NaN sample does not turn every figure into NaN;
# identical to np.max when aur_spec has no NaNs.
aur_peak = np.nanmax(aur_spec)
aur_total = np.nansum(aur_spec / aur_peak)
print(np.nansum(excess_ch4_hot) / aur_total * 100)
print(np.nansum(modelled_h3p) / aur_total * 100)
print((np.nansum(aur_spec - nonaur_spec) / aur_peak
       - np.nansum(modelled_h3p + excess_ch4_hot)) / aur_total * 100)
print(aur_total)